# Remove every object returned by ls() from the current environment so the
# script starts from an empty workspace.
# NOTE(review): this also discards any `path` object defined before the script
# is run, while read.csv(path) below needs one — confirm where `path` is meant
# to be defined.
remove(list = ls())

library(tidyverse)
library(randomForest)
library(lme4)
library(stargazer)
library(caret)
library(ggridges)
library(pdp)
library(DiagrammeR)

# Load the data set inherited from script 51 and clean it in one pipeline.
# The only cleaning step is dropping the `X` column (presumably the row-index
# column added when the file was written — confirm against script 51).
complete_set <- read.csv(path) %>%
  dplyr::select(-X)

##########################
#### Making functions ####
##########################

# Random forest function
#
# Fits a classification random forest that predicts `number_rating`
# (converted to a factor inside the formula) from every other column of `df`.
#
# Args:
#   df:       data frame holding `number_rating` and the predictor columns.
#   ntree:    number of trees to grow. Defaults to 10000, the value that was
#             previously hard-coded.
#   sampsize: number of rows drawn for each tree (a single number). Defaults
#             to 1000, the value that was previously hard-coded. It is capped
#             at nrow(df) so a data frame with fewer rows than `sampsize` does
#             not request more rows than it has.
#
# Returns: the fitted randomForest object, with importance measures computed.
do_forest <- function(df, ntree = 10000, sampsize = 1000) {
  # Number of candidate predictors per split: square root of the predictor
  # count (all columns except the response), rounded.
  n_try <- round(sqrt(ncol(df) - 1))

  randomForest(
    as.factor(number_rating) ~ .,
    data = df,
    ntree = ntree,
    mtry = n_try,
    sampsize = min(sampsize, nrow(df)),
    importance = TRUE
  )
}

# Visualize forest function
#
# Draws a horizontal dot plot of the variables with the largest mean decrease
# in accuracy in a fitted random forest.
#
# Args:
#   forest_object: forest whose importance() matrix contains the per-class
#                  columns followed by the two overall measures (i.e. a
#                  classification forest fitted with importance = TRUE).
#   n_top:         how many of the highest-ranked variables to plot. Defaults
#                  to 10, the value that was previously hard-coded.
#
# Returns: a ggplot object, so callers can keep adding layers or scales.
visualize_forest <- function(forest_object, n_top = 10) {
  model_vis <- as.data.frame(importance(forest_object))

  # The last two importance columns are the overall accuracy and Gini
  # measures; everything before them is one column per outcome class. Derive
  # the class count from the matrix instead of assuming exactly four classes.
  n_class <- ncol(model_vis) - 2
  if (n_class < 0) {
    stop(
      "importance() returned fewer than two columns; ",
      "was the forest fitted with importance = TRUE?",
      call. = FALSE
    )
  }

  # Move the variable names out of the row names into a regular column.
  model_vis <- cbind(rownames(model_vis), model_vis)

  colnames(model_vis) <- c(
    "variable",
    as.character(seq_len(n_class)),
    "accuracy",
    "gini"
  )

  # Factor whose levels are ordered by accuracy, so the axis is sorted by
  # importance rather than alphabetically.
  model_vis$name <- factor(
    model_vis$variable,
    levels = model_vis$variable[order(model_vis$accuracy)]
  )

  model_vis %>%
    arrange(desc(accuracy)) %>%
    head(n_top) %>%
    ggplot(aes(x = name, y = accuracy)) +
    geom_point(size = 3, alpha = .3) +
    # Dashed guide line for each variable, spanning 0 to the largest plotted
    # accuracy value.
    geom_segment(aes(xend = name,
                     y = 0,
                     yend = max(accuracy)),
                 linetype = "dashed",
                 size = 0.1) +
    theme_bw() +
    labs(title = "Dot plot",
         subtitle = "Variables sorted by mean decrease in accuracy",
         x = "",
         y = "") +
    coord_flip()
}

# Confusion matrix function
#
# Predicts the class of every row in `valid_set` with `model` and compares the
# predictions with the observed `number_rating` values.
#
# Args:
#   model:     fitted model accepted by predict(model, newdata, type = "class").
#   train_set: kept so existing calls keep working; it is no longer used. The
#              training-set predictions that used to be computed here were
#              thrown away without being returned or printed.
#   valid_set: data frame with the predictors and the `number_rating` column.
#
# Returns: the result of confusionMatrix() on the predicted-by-observed table.
do_confusion <- function(model, train_set, valid_set) {
  pred_valid <- predict(model, valid_set, type = "class")
  observed <- valid_set$number_rating

  # Give predictions and observations one shared set of levels. Without this,
  # a class that is absent from the validation data produces a non-square
  # table, which confusionMatrix() cannot process.
  if (is.factor(pred_valid)) {
    seen <- as.character(unique(observed[!is.na(observed)]))
    class_levels <- union(levels(pred_valid), seen)
    pred_valid <- factor(pred_valid, levels = class_levels)
    observed <- factor(observed, levels = class_levels)
  }

  # dnn reproduces the dimension names the table had before the alignment.
  confusion <- table(pred_valid, observed, dnn = c("pred_valid", ""))

  confusionMatrix(confusion)
}

# Machine learning: reproducible 70/30 split, model fit and diagnostics
set.seed(123)

# Row indices of the training set: 70% of the rows, drawn without replacement.
total_train <- sample.int(
  nrow(complete_set),
  size = 0.7 * nrow(complete_set),
  replace = FALSE
)

# Sampled rows train the model; every remaining row is held out for validation.
complete_train <- complete_set[total_train, ]
complete_valid <- complete_set[-total_train, ]

model_complete <- do_forest(complete_train)

# Importance plot with readable axis labels.
# NOTE(review): the labels are an unnamed vector, so they appear to be matched
# to the plotted variables by position — confirm the order still matches the
# importance ranking whenever the data or the seed changes.
visualize_forest(model_complete) +
  scale_x_discrete(
    labels = c(
      "Progress: reading",
      "N (teachers)",
      "Income SEN-students",
      "Progress: math",
      "Year",
      "Student absence (%)",
      "SEN-students (%)",
      "Progress: writing",
      "Student achievement",
      "Satisfied parents (%)"
    )
  )

# Confusion matrix of the fitted model on the held-out validation rows.
do_confusion(model_complete, complete_train, complete_valid)

########################
#### Visualizations ####
########################
library(patchwork)

# Partial dependence plots
#
# plot_outstanding_pd() builds one partial dependence plot; it replaces five
# copies of the same partial() call that differed only in the predictor and
# the x-axis label.
#
# Args:
#   pred_var:  name of the predictor column to vary (character).
#   x_label:   x-axis label for the plot.
#   model:     fitted forest to explain; defaults to `model_complete`.
#   train_set: training data handed to partial(); defaults to `complete_train`.
#
# Returns: a ggplot object showing the predicted probability of class 4.
#
# NOTE(review): which.class = 4 is plotted as "Pr(Outstanding)" — confirm that
# the fourth level of number_rating really is the "Outstanding" rating.
# NOTE(review): `probs` looks like it is only consulted when quantiles = TRUE;
# it is kept unchanged here — confirm against the pdp documentation.
plot_outstanding_pd <- function(pred_var, x_label,
                                model = model_complete,
                                train_set = complete_train) {
  partial(model, train = train_set, pred.var = pred_var,
          plot.engine = "ggplot2", type = "auto", prob = TRUE,
          probs = c(0, 1), which.class = 4, plot = TRUE,
          trim.outliers = TRUE) +
    theme_bw() +
    labs(x = x_label,
         y = "Pr(Outstanding)")
}

satisfaction <- plot_outstanding_pd("satisfaction", "Satisfaction")
absence <- plot_outstanding_pd("absence_perc", "% Absence")
expect <- plot_outstanding_pd("expected_point_score", "Exp. point score")
writprog <- plot_outstanding_pd("prog_writ", "Writing progress")
prog_math <- plot_outstanding_pd("prog_math", "Math progress")

# Patchwork layout: satisfaction across the top, then two rows of two plots.
(satisfaction) /
  (expect | writprog) /
  (absence | prog_math)
